%% Aim
% Predict choice and RT from the aDDM (HDDM) simulation estimates and compare them with observed behavior

%% Setup
% Clear the command window and workspace, start the script-wide timer
% (read by toc at the end), and add every folder below the current
% directory to the MATLAB path.
clc; clear; tic;
root_path = pwd();
addpath(genpath(root_path));

% Design dimensions used by the matrix figures below.
nrole    = 2;
nprice   = 10;
nlottery = 10;

% Role names and three shades of each role's color; index 1 = Buyer, 2 = Seller.
role_label        = {'Buyer', 'Seller'};
role_color_dark   = {[235 112  99]/255, [106 142 208]/255};
role_color_median = {[238 137 126]/255, [144 171 220]/255};
role_color_light  = {[243 171 163]/255, [165 187 227]/255};

% Labels/colors for the 'Val' / 'Res' plots (they appear unused in this script).
% NOTE(review): there are three labels but only two colors.
bias_label        = {'Val', 'Res', 'Val vs. Res'};
bias_color_median = {[195 205 147]/255, [192 173 203]/255};

%% Load data
% Read the subject, trial-level and aDDM-simulation tables from root_path.
% fullfile uses the platform's separator, so the script also runs outside
% Windows (a hard-coded '\' does not).
sub_data = readtable(fullfile(root_path, 'data', 'subject', 'subject.csv'));
nsub = height(sub_data);

beh_data        = readtable(fullfile(root_path, 'data', 'behavior', 'behavior.csv'));
gaze_trial_data = readtable(fullfile(root_path, 'data', 'gaze', 'gaze_trial.csv'));
eye_trial_data  = readtable(fullfile(root_path, 'data', 'eye', 'eye_trial.csv'));

% The three trial tables are joined side by side, so row i of each must be
% the same trial. Columns 1:8 of the eye/gaze tables are dropped, presumably
% because they repeat the behavior table's key columns — NOTE(review): confirm.
assert(height(eye_trial_data) == height(beh_data) && height(gaze_trial_data) == height(beh_data), ...
    'behavior (%d), eye (%d) and gaze (%d) trial tables must have the same number of rows.', ...
    height(beh_data), height(eye_trial_data), height(gaze_trial_data));
trial_data = [beh_data, eye_trial_data(:,9:end), gaze_trial_data(:,9:end)];

% Keep only trials flagged ValidGaze == 1 with GazeNum > 0.
trial_data = trial_data(trial_data.ValidGaze==1 & trial_data.GazeNum>0, :);

ahddm_data = readtable(fullfile(root_path, 'model', 'hddm', 'aDDM4a', 'aDDM4a_hddm_rt_simulated.csv'));

%% Incorporate data
% Attach the simulated aDDM RTs to the filtered trials. The rows are
% assumed to be the same trials in the same order — NOTE(review): confirm
% against the script that wrote aDDM4a_hddm_rt_simulated.csv.
assert(height(ahddm_data) == height(trial_data), ...
    'aDDM simulation has %d rows but the filtered trial_data has %d.', ...
    height(ahddm_data), height(trial_data));
trial_data.aDDMRTSim = ahddm_data.RTSim;

% The simulated RT is signed: its magnitude is used as the predicted RT and
% its sign (> 0) as the predicted boundary reached.
trial_data.aDDMRTPred = abs(trial_data.aDDMRTSim);
trial_data.aDDMEndingPred = trial_data.aDDMRTSim > 0;

% Map boundary to choice per role (assumes Role is 1 or 2):
%   Role 1 -> choice = aDDMEndingPred
%   Role 2 -> choice = 1 - aDDMEndingPred
trial_data.aDDMChoicePred = (2-trial_data.Role) .* trial_data.aDDMEndingPred + (trial_data.Role-1) .* (1-trial_data.aDDMEndingPred);

%% Predicted Probability of trading: Matrix figure (group-level)
% For each role, average the predicted choice over all trials in each
% (lottery, price) cell. Matrix rows are lottery values 2, 4, ..., 2*nlottery
% and columns are price values 1..nprice. Empty cells are NaN.

    for irole=1:nrole
        prob_accept.matrix.group.role(irole).value = nan(nlottery, nprice);
        for iprice=1:nprice
            for ilottery=1:nlottery
                in_cell = trial_data.Role==irole & ...
                          trial_data.Price==iprice & ...
                          trial_data.Lottery==ilottery*2;
                prob_accept.matrix.group.role(irole).value(ilottery,iprice) = ...
                    mean(trial_data.aDDMChoicePred(in_cell), 'omitnan');
            end
        end
    end

    % Diverging colormap: dark blue -> white (lower half), white -> dark red
    % (upper half). It does not depend on the role, so build it once.
    cmap_down = [linspace(0/256,1,1000)', linspace(32/256,1,1000)', linspace(102/256,1,1000)'];
    cmap_up   = [linspace(1,192/256,1000)', linspace(1,0,1000)', linspace(1,0,1000)'];
    cmap = [cmap_down; cmap_up];

    fig = figure('Renderer', 'painters', 'Position', [10 10 1000 350]);
    for irole=1:nrole
        subplot(1,2,irole);
        imagesc(prob_accept.matrix.group.role(irole).value);
        colormap(cmap);
        set(gca,'TickDir','out');
        xlabel('Price');
        ylabel('Lottery');
        % Rows (y axis) are lotteries, columns (x axis) are prices.
        yticks(1:nlottery);
        xticks(1:nprice);
        yticklabels(arrayfun(@num2str, (1:nlottery)*2, 'UniformOutput', false));
        xticklabels(arrayfun(@num2str, 1:nprice, 'UniformOutput', false));
        colorbar;
        caxis([0 1]);
        title(role_label{irole});
    end
    % Print and close this figure by handle instead of assuming it is figure 1.
    print(fig, '-dtiff', fullfile(root_path, 'figure', 'behavior', 'choice', 'matrix', 'prob_accept_matrix_group_pred_addm.tif'), '-r200');
    close(fig);

%% Raw Choice vs. Predicted Choice
% Per-subject proportion of trials whose observed Choice equals the
% aDDM-predicted choice. Row isub of sub_data is matched to Subject == isub
% — NOTE(review): assumes subject IDs are 1..nsub in row order.
sub_data.aDDMChoicePred = nan(nsub, 1);
for isub = 1:nsub
    in_sub = trial_data.Subject == isub;
    sub_data.aDDMChoicePred(isub) = mean(trial_data.Choice(in_sub) == trial_data.aDDMChoicePred(in_sub));
end

% Report the across-subject summary (previously computed and discarded).
fprintf('aDDM choice prediction accuracy: mean = %.4f, SD = %.4f, min = %.4f, max = %.4f\n', ...
    mean(sub_data.aDDMChoicePred), std(sub_data.aDDMChoicePred), ...
    min(sub_data.aDDMChoicePred), max(sub_data.aDDMChoicePred));

%% Predicted RT: Matrix figure (group-level)
% For each role, average the predicted RT over all trials in each
% (lottery, price) cell. Matrix rows are lottery values 2, 4, ..., 2*nlottery
% and columns are price values 1..nprice. Empty cells are NaN.

    for irole=1:nrole
        rt.matrix.group.role(irole).value = nan(nlottery, nprice);
        for iprice=1:nprice
            for ilottery=1:nlottery
                in_cell = trial_data.Role==irole & ...
                          trial_data.Price==iprice & ...
                          trial_data.Lottery==ilottery*2;
                rt.matrix.group.role(irole).value(ilottery,iprice) = ...
                    mean(trial_data.aDDMRTPred(in_cell), 'omitnan');
            end
        end
    end

    % Diverging colormap: dark blue -> white (lower half), white -> dark red
    % (upper half). It does not depend on the role, so build it once.
    cmap_down = [linspace(0/256,1,1000)', linspace(32/256,1,1000)', linspace(102/256,1,1000)'];
    cmap_up   = [linspace(1,192/256,1000)', linspace(1,0,1000)', linspace(1,0,1000)'];
    cmap = [cmap_down; cmap_up];

    % Color-axis limits per role (index matches Role).
    rt_caxis = {[0.9 1.2], [0.9 1.4]};

    fig = figure('Renderer', 'painters', 'Position', [10 10 1000 350]);
    for irole=1:nrole
        subplot(1,2,irole);
        imagesc(rt.matrix.group.role(irole).value);
        colormap(cmap);
        set(gca,'TickDir','out');
        xlabel('Price');
        ylabel('Lottery');
        % Rows (y axis) are lotteries, columns (x axis) are prices.
        yticks(1:nlottery);
        xticks(1:nprice);
        yticklabels(arrayfun(@num2str, (1:nlottery)*2, 'UniformOutput', false));
        xticklabels(arrayfun(@num2str, 1:nprice, 'UniformOutput', false));
        colorbar;
        caxis(rt_caxis{irole});
        title(role_label{irole});
    end
    % Print and close this figure by handle instead of assuming it is figure 1.
    print(fig, '-dtiff', fullfile(root_path, 'figure', 'behavior', 'rt', 'matrix', 'rt_matrix_group_pred_addm.tif'), '-r200');
    close(fig);

%% Raw RT vs. Predicted RT
% Per-subject correlation (corr's default, Pearson) between observed RT and
% aDDM-predicted RT. The result is stored in sub_data.aDDMRTPred, which
% therefore holds a correlation coefficient, not an RT. Row isub of
% sub_data is matched to Subject == isub — NOTE(review): confirm IDs are 1..nsub.
sub_data.aDDMRTPred = nan(nsub, 1);
for isub = 1:nsub
    in_sub = trial_data.Subject == isub;
    sub_data.aDDMRTPred(isub) = corr(trial_data.RT(in_sub), trial_data.aDDMRTPred(in_sub));
end

% Report the across-subject summary (previously computed and discarded).
fprintf('Observed vs. aDDM-predicted RT correlation: mean r = %.4f, SD = %.4f, min = %.4f, max = %.4f\n', ...
    mean(sub_data.aDDMRTPred), std(sub_data.aDDMRTPred), ...
    min(sub_data.aDDMRTPred), max(sub_data.aDDMRTPred));

% One-sample, right-tailed t-test: are the per-subject correlations > 0?
[h,p,ci,stats] = ttest(sub_data.aDDMRTPred, 0, 'tail', 'right');
fprintf('One-sample t-test (right-tailed) of r > 0: t(%d) = %.3f, p = %.4g\n', ...
    stats.df, stats.tstat, p);

%% save
% Write the subject table, including the per-subject aDDM prediction
% columns added above, next to the input subject.csv. fullfile keeps the
% path portable across platforms.
writetable(sub_data, fullfile(root_path, 'data', 'subject', 'subject_addm_pred.csv'));

%%
% Report the elapsed time since the tic in the Setup section.
toc;
